🦙 Introduction

The Pictionary dataset provided was designed for computational models to classify which of 6 candidate words is drawn in a 28 x 28 pixel grayscale image, much like the popular drawing-and-guessing game Pictionary. Each observation has 784 variables representing the 28 x 28 grid of the input drawing, which we attempt to classify with a model. There are 6,000 training observations and 1,200 test observations which require prediction. Each observation represents one drawing grid together with the word category it belongs to (each observation must belong to exactly one of the 6 words), as shown below.

🔧 Methodology

Initial Testing

In the initial phase, we tested a random forest and an SVM on the ‘sketches’ data, but the accuracy rates were only 80% and 82% respectively. Seeking higher accuracy, we turned to convolutional neural networks (CNNs), deeper networks that have consistently performed well on image recognition in the annual ImageNet competition. Table 1 shows the results after testing several architectures, namely LeNet-5 (Dasaprakash, 2019), VGG16 (Deshmukh, 2018), ResNet50 (Rosebrock, 2017) and 4 other custom architectures modified from these 3.

Data Preparation For CNN

Some data cleaning steps are required before a neural network can be trained. Firstly, the ‘sketches’ data needs numeric inputs, so all 6 classifications are converted to the numbers 0 to 5 in alphabetical order. Next, neural networks are sensitive to the scale of the feature values, so all values of variables 1 to 784 are divided by 255 (the maximum pixel value) to standardize them to the range 0 to 1. Since the ‘sketches’ labels are categorical, they must be converted to a binary class matrix, which is done by one-hot encoding the numeric labels with the to_categorical() function from Keras. Lastly, convolutional neural networks require the observation data in 4 dimensions, so the ‘sketches’ variables were reshaped into a 4-dimensional array using the array_reshape() function.
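The label-encoding and pixel-scaling steps above can be sketched in a few lines. The report itself used R's Keras functions (to_categorical(), array_reshape()); this stand-alone Python sketch with illustrative function names only shows the idea.

```python
# Illustrative sketch of the preprocessing described above; names are
# hypothetical, not from the original R code.

def scale_pixels(pixels, max_value=255):
    """Scale raw grayscale values into the range [0, 1]."""
    return [p / max_value for p in pixels]

def one_hot(label, num_classes=6):
    """One-hot encode an integer class label (0..num_classes-1)
    into one row of the binary class matrix."""
    vec = [0] * num_classes
    vec[label] = 1
    return vec

# The 6 words, sorted alphabetically, map to labels 0..5.
print(scale_pixels([0, 51, 255]))  # [0.0, 0.2, 1.0]
print(one_hot(2))                  # [0, 0, 1, 0, 0, 0]
```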

Best CNN Model Architecture

The highest validation and training accuracy came from our custom 1 (Govoruha, 2019) architecture (figure 1), which has 7 layers and adopts the initial layering structure of VGG16 with the addition of batch normalization and dropout. Rather than going deeper with more 3x3 convolutional layers as VGG16 does, an 8x8 kernel is used to increase the receptive field and incorporate more information (Ghosh, 2017). The CNN process is illustrated in figure 2.

In the custom 1 architecture, a batch normalization layer is added before the max pooling layer to normalize the output of the previous activation layer. Normalizing the inputs of each layer in this way increases the stability of the neural network during training (FD, 2017). After the max pooling layer, a dropout layer is incorporated to simulate having a large number of different network architectures by randomly dropping out nodes during training. This regularization method is used to reduce overfitting and improve generalization error, and is especially effective for relatively small datasets such as the ‘sketches’ data (Brownlee, 2019).
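The core of the batch-normalization step can be illustrated numerically: activations are shifted to zero mean and scaled to unit variance across the batch. This is a simplified Python sketch (the learned scale-and-shift parameters gamma and beta of a real Keras layer are omitted).

```python
import math

def batch_norm(batch, eps=1e-5):
    """Normalize a batch of activations to zero mean, unit variance.
    eps avoids division by zero for a constant batch."""
    mean = sum(batch) / len(batch)
    var = sum((x - mean) ** 2 for x in batch) / len(batch)
    return [(x - mean) / math.sqrt(var + eps) for x in batch]

normed = batch_norm([1.0, 2.0, 3.0, 4.0])
print(sum(normed))  # ~0: the normalized batch has zero mean
```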

Given that our task is multi-class classification, the Keras categorical cross-entropy is the most suitable loss function with which to compile the model (James, Witten, Hastie & Tibshirani, 2013). Next, we chose the ‘adam’ (adaptive moment estimation) optimizer because it combines the advantages of both ‘AdaGrad’ and ‘RMSProp’ (Brownlee, 2017), resulting in an optimizer that is efficient and highly adaptive in its learning rate. During model training, the weights and biases are updated by the ‘adam’ optimizer to minimize the categorical cross-entropy loss (Wang, 2018).
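For a one-hot target, categorical cross-entropy reduces to minus the log of the probability the model assigns to the true class. A minimal Python sketch of this calculation (the function name is illustrative):

```python
import math

def categorical_cross_entropy(y_true, y_pred):
    """Loss for one observation: -log of the predicted probability
    of the true class, given a one-hot target vector."""
    return -sum(t * math.log(p) for t, p in zip(y_true, y_pred) if t > 0)

# A confident, correct prediction gives a small loss...
print(categorical_cross_entropy([0, 1, 0], [0.05, 0.90, 0.05]))  # ~0.105
# ...while moving probability mass away from the true class raises it.
print(categorical_cross_entropy([0, 1, 0], [0.45, 0.10, 0.45]))  # ~2.303
```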

Data Augmentation

With only 6,000 samples (compared with MNIST's 60,000), data augmentation is applied during model training by generating additional data from the existing ‘sketches’ data (Brownlee, 2019). This method overcomes the limited quantity of data by randomly zooming, rotating and shifting the images (illustrated in image 1), allowing the model to be trained on a variety of orientations and scales and in turn increasing its recognition capability (Gandhi, 2018). Incorporating data augmentation in the model training gave an average boost of 1.2% in validation accuracy.
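One of the augmentations above, a horizontal shift, can be shown on a tiny grid. The report used Keras' image data generator for this; the following stand-alone Python sketch (with an illustrative function name) only demonstrates the idea of producing a new training image from an existing one.

```python
def shift_right(image, pixels):
    """Shift each row of a 2D grayscale grid right by `pixels`,
    padding the newly exposed left edge with zeros (background)."""
    return [[0] * pixels + row[:len(row) - pixels] for row in image]

img = [[1, 2, 3],
       [4, 5, 6]]
print(shift_right(img, 1))  # [[0, 1, 2], [0, 4, 5]]
```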

Voting Ensemble

Convolutional neural networks learn through a stochastic training algorithm; as a result they can be sensitive to the training data and arrive at different weights on each training run, producing different predictions for the same ‘sketches_test’ data (Brownlee, 2019). This variance can be reduced with a voting ensemble, an ensemble technique whereby the class is determined by majority vote (Brownlee, 2020). To perform a hard voting ensemble, we trained 6 custom 1 models, each on a different random split of the ‘sketches’ data; for each observation, the class receiving the most votes gives the overall prediction (Brownlee, 2020). The voting ensemble improved accuracy by approximately 0.8%.
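The hard-voting step itself is simple: each trained model casts one class vote per observation and the majority wins. A minimal Python sketch (the function name is illustrative; the report implemented this in R):

```python
from collections import Counter

def hard_vote(predictions):
    """Return the class predicted by the most models for one observation."""
    return Counter(predictions).most_common(1)[0][0]

# e.g. the 6 trained models voting on one sketch (class labels 0-5):
print(hard_vote([3, 3, 1, 3, 5, 3]))  # 3
```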

🌱 Results and Discussion

Figure 3 shows how each filter in the top two convolutional layers activates in response to a single image of a kangaroo from the data. The filters in the top layer seem to be picking up on low-level features such as left and right edges, and the general body shape.

The second layer has much less of the original image activated in each filter, which indicates that the layer is picking up on higher-level features. It seems that in some cases the layer is capturing the extremities of the picture, such as ears, arms, and the tail, and in some filters the image of the kangaroo has been broken down into a more basic form. There are also many filters with more unintelligible outputs, and some that do not activate at all.

Figure 4 shows test/training loss plotted against epoch and Figure 5 shows test/training accuracy plotted against epoch. Note that these figures refer to one of the six runs used to build our model; however, the loss and accuracy curves of all 6 models were very similar, as they share the same architecture. As shown in Figure 5, the test accuracy increases with epoch until approximately epoch 20, where it becomes roughly constant, varying between 90% and 95%. The callback_model_checkpoint() function was used to capture the model with the highest validation accuracy, and all 6 models were captured between 25 and 50 epochs.

Table 2 shows the training accuracy of each individual network before ensembling. Our ensemble of networks outperformed each individual model, with an accuracy of 0.980 against an average of 0.973, suggesting our voting method successfully decreased model variance.

Figure 6 shows the images which were misclassified by the model. It is clear why the model struggled to classify them correctly! After inspecting the individual sketches, we believe that most of them are either too abstract for even a human to recognize, or can very easily be mistaken for another object (such as banana and boomerang).

Table 3 shows the confusion matrix for our model, which depicts which categories the model is ‘confusing’ for other categories. Correct predictions are counted down the diagonal, whereas wrong predictions are counted in the off-diagonal cells according to which incorrect category each image was ‘confused’ with (Narkhede, 2018).
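The tallying behind a confusion matrix like Table 3 can be sketched directly: rows index the true class, columns the predicted class, so correct predictions accumulate on the diagonal. An illustrative Python version:

```python
def confusion_matrix(y_true, y_pred, num_classes):
    """Count predictions into a num_classes x num_classes grid:
    m[true][predicted] += 1 for each observation."""
    m = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        m[t][p] += 1
    return m

truth = [0, 0, 1, 1, 2]
preds = [0, 1, 1, 1, 2]
for row in confusion_matrix(truth, preds, 3):
    print(row)
# [1, 1, 0]   <- one class-0 image 'confused' for class 1
# [0, 2, 0]
# [0, 0, 1]
```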

The categories predicted correctly most often were crab and flip flops, each with only 8 images out of 1000 predicted incorrectly. This is likely because crab and flip flops are quite distinct from the other image categories. They were followed by cactus at 11 incorrect, kangaroo and banana at 17 incorrect each, and finally boomerang at 58 incorrect. This suggests that the model was able to predict every category except boomerang reasonably well.

One would expect that, as bananas and boomerangs look quite similar, the model may find differentiating these categories difficult. This is evident in the confusion matrix, with 35 boomerang images misclassified as bananas. Interestingly, the effect is not reciprocated for bananas, with only 4 banana images misclassified as boomerangs. Several features can distinguish a banana from a boomerang, such as an opened peel, being drawn from a three-dimensional perspective, or having a sharp tip representing the stem. However, boomerangs were more likely to be drawn as the archetypal “two-dimensional curve”, making misclassification a greater problem.

To investigate this further, figure 7 displays the 35 boomerang images that were misclassified as bananas. A common feature in this set of images is a ‘boomerang’ with a sharp point, on either side (e.g. images 1, 2, 3, 4) or just on one side (e.g. images 5, 15, 16, 17).

Figure 8 shows a sample of 35 correctly classified boomerangs (left) and bananas (right). In general, boomerangs were mostly drawn with curved edges, whereas bananas were mostly drawn with sharp edges. The model likely misclassified the boomerangs in figure 7 because they displayed the sharp point common to the majority of the banana images. Furthermore, the fact that only 4 bananas were misclassified as boomerangs suggests that drawn bananas almost always had a pointed edge.

🥮 Conclusion

In conclusion, we learnt that CNNs are a strong model class for classifying images. To get the best out of our model, we needed to counteract its sensitivity to the training data, so we employed a voting ensemble that reduced the variance. Overall, the use of a CNN, deeper layering and the voting ensemble drastically improved our accuracy compared with the originally tested SVM and random forest classification models. While exploring the data we also made many interesting observations about why our model had consistent difficulty classifying some sketches: some were too abstract, or too similar to features of a different category, which caused misclassification and accounted for most of our error. In future we would like to draw on these observations, perhaps implementing further layering to improve our CNN model for the Pictionary dataset and thus its ability to classify future sketches.

🍎 References

Sources

Brownlee, J. (2019, December 3). A Gentle Introduction to Dropout for Regularizing Deep Neural Networks. Retrieved from
https://machinelearningmastery.com/dropout-for-regularizing-deep-neural-networks/

Brownlee, J. (2018, December 19). Ensemble Learning Methods for Deep Learning Neural Networks. Retrieved from https://machinelearningmastery.com/ensemble-methods-for-deep-learning-neural-networks/

Brownlee, J. (2017, July 3). Gentle Introduction to the Adam Optimization Algorithm for Deep Learning. Retrieved from
https://machinelearningmastery.com/adam-optimization-algorithm-for-deep-learning/

Brownlee, J. (2019, April 12). How to Configure Image Data Augmentation in Keras. Retrieved from
https://machinelearningmastery.com/how-to-configure-image-data-augmentation-when-training-deep-learning-neural-networks/

Brownlee, J. (2020, April 17). How to Develop Voting Ensembles With Python. Retrieved from
https://machinelearningmastery.com/voting-ensembles-with-python/

Brownlee, J. (2019, August 6). How To Improve Deep Learning Performance. Retrieved from
https://machinelearningmastery.com/improve-deep-learning-performance/

Dasaprakash, K. (2019, February 19). LeNet-5 CNN with Keras - 99.48%. Retrieved from
https://www.kaggle.com/curiousprogrammer/lenet-5-cnn-with-keras-99-48/input#LeNet-5-CNN-with-Keras:

Deotte, C. (2020, February 18). 25 Million Images! [0.99757] MNIST. Retrieved from
https://www.kaggle.com/cdeotte/25-million-images-0-99757-mnist

Deshmukh, A. (2018, May 8). Classify Fashion_Mnist with VGG16. Retrieved from
https://www.kaggle.com/anandad/classify-fashion-mnist-with-vgg16/input

Falbel, D., Allaire, J. J., Chollet, F., RStudio & Google. (2020, May 30). Tutorial: Save and Restore Models. Retrieved from
https://keras.rstudio.com/articles/tutorial_save_and_restore.html

FD. (2017, October 21). Batch normalization in Neural Networks. Retrieved from
https://towardsdatascience.com/batch-normalization-in-neural-networks-1ac91516821c

Gandhi, A. (2018). Data Augmentation | How to use Deep Learning when you have Limited Data — Part 2. Retrieved from
https://nanonets.com/blog/data-augmentation-how-to-use-deep-learning-when-you-have-limited-data-part-2/

Ghouzam, Y. (2017, August 18). Introduction to CNN Keras - 0.997 (top 6%). Retrieved from
https://www.kaggle.com/yassineghouzam/introduction-to-cnn-keras-0-997-top-6/notebook

Ghosh, T. (2017, April 13). When is a large-sized kernel useful in CNN?. Retrieved from
https://www.quora.com/When-is-a-large-sized-kernel-useful-in-CNN

Gokirmak, G. M. (2018, November 27). Keras CNN multi model ensemble with voting. Retrieved from
https://www.kaggle.com/mgiraygokirmak/keras-cnn-multi-model-ensemble-with-voting/notebook

Govoruha, P. (2019, October 8). R Keras CNN. Retrieved from
https://www.kaggle.com/govoruha/r-keras-cnn/

James, G., Witten, D., Hastie, T., & Tibshirani, R. (2013). An Introduction to Statistical Learning with Applications in R. Retrieved from
http://faculty.marshall.usc.edu/gareth-james/ISL/ISLR%20Seventh%20Printing.pdf

LeCun, Y., Cortes, C. & Burges, J. C. C. (2020, May 30). THE MNIST DATABASE of handwritten digits. Retrieved from
http://yann.lecun.com/exdb/mnist/

Narkhede, S. (2018, May 9). Understanding Confusion Matrix. Retrieved from
https://towardsdatascience.com/understanding-confusion-matrix-a9ad42dcfd62

Pai, P. (2017, October 25). Data Augmentation Techniques in CNN using Tensorflow. Retrieved from
https://medium.com/@prasad.pai/data-augmentation-techniques-in-cnn-using-tensorflow-371ae43d5be9

Rosebrock, A. (2017, March 20). ImageNet: VGGNet, ResNet, Inception, and Xception with Keras. Retrieved from
https://www.pyimagesearch.com/2017/03/20/imagenet-vggnet-resnet-inception-xception-keras/

Shen, K. (2018, June 20). Effect of batch size on training dynamics. Retrieved from
https://medium.com/mini-distill/effect-of-batch-size-on-training-dynamics-21c14f7a716e

Thakur, R. (2019, August 6). Step by step VGG16 implementation in Keras for beginners. Retrieved from
https://towardsdatascience.com/step-by-step-vgg16-implementation-in-keras-for-beginners-a833c686ae6c

Tsang, SK. (2018, August 8). Review: LeNet-1, LeNet-4, LeNet-5, Boosted LeNet-4 (Image Classification). Retrieved from
https://medium.com/@sh.tsang/paper-brief-review-of-lenet-1-lenet-4-lenet-5-boosted-lenet-4-image-classification-1f5f809dbf17

Wang, CF. (2018, September 1). Finding the Cost Function of Neural Networks. Retrieved from
https://towardsdatascience.com/step-by-step-the-math-behind-neural-networks-490dc1f3cfd9


R Packages

Bache, M. S. & Wickham, H. (2014). magrittr: A Forward-Pipe Operator for R. R package version 1.5. Retrieved from
https://cran.r-project.org/web/packages/magrittr/index.html

Breiman, L., Cutler, A., Liaw, A. & Wiener, M. (2018). randomForest: Breiman and Cutler’s Random Forests for Classification and Regression. R package version 4.6-14. Retrieved from
https://cran.r-project.org/web/packages/randomForest/index.html

Dowle, M., Srinivasan, A., Gorecki, J., Chirico, M., Stetsenko, P., Short, T., Lianoglou, S., Antonyan, E., Bonsch, M., Parsonage, H., Ritchie, S., Ren, K., Tan, X., Saporta, R., Seiskari, O., Dong, X., Lang, M., Iwasaki, W., Wenchel, S., Broman, K., Schmidt, T., Arenburg, D., Smith, E., Cocquemas, F., Gomez, M., Chataignon, P., Groves, D., Possenriede, D., Parages, F., Toth, D., Yaramaz-David, M., Perumal, A., Sams, J., Morgan, M., Quinn, M., Storey, R., Saraswat, M., Jacob, M., Schubmehl, M. & Vaughan, D. (2019). data.table: Extension of ‘data.frame’. R package version 1.12.8. Retrieved from
https://cran.r-project.org/web/packages/data.table/index.html

Falbel, D., Allaire, J. J., Chollet, F., Tang, Y., Bijl, VD. B., Studer, M. & Keydana, S. (2020). keras: R Interface to ‘Keras’. R package version 2.3.0.0. Retrieved from
https://cran.r-project.org/web/packages/keras/index.html

Falbel, D., Allaire, J. J., RStudio, Tang, Y., Eddelbuettel, D., Golding, N., Kalinowski, T. & Google Inc. (2020). tensorflow: R Interface to ‘TensorFlow’. R package version 2.2.0. Retrieved from
https://CRAN.R-project.org/package=tensorflow

Hamner, B., Frasco, M. & LeDell, E. (2018). Metrics: Evaluation Metrics for Machine Learning. R package version 0.1.4. Retrieved from
https://cran.r-project.org/web/packages/Metrics/index.html

Kuhn, M., Wing, J., Weston, S., Williams, A., Keefer, C., Engelhardt, A., Cooper, T., Mayer, Z., Kenkel, B., Benesty, M., Lescarbeau, R., Ziem, A., Scrucca, L., Tang, Y., Candan, C. & Hunt, T. (2020). caret: Classification and Regression Training. R package version 6.0-86. Retrieved from
https://cran.r-project.org/web/packages/caret/index.html

Wickham, H. (2019). tidyverse: Easily Install and Load the ‘Tidyverse’. R package version 1.3.0. Retrieved from
https://cran.r-project.org/web/packages/tidyverse/index.html